Session 1. End to End ML

Introduction

According to the National Heart, Lung and Blood Institute:

Heart disease is a catch-all phrase for a variety of conditions that affect the heart’s structure and function. Coronary heart disease is a type of heart disease that develops when the arteries of the heart cannot deliver enough oxygen-rich blood to the heart. It is the leading cause of death in the United States.

(Emphasis mine. Source: https://www.nhlbi.nih.gov/health-topics/espanol/enfermedad-coronaria)

Also, according to the World Health Organization, cardiovascular diseases are the leading cause of death globally (source: https://www.who.int/news-room/fact-sheets/detail/cardiovascular-diseases-(cvds)).

In this notebook we try to learn enough about this topic to understand the Heart Disease UCI dataset and to build simple models that predict whether a patient has heart disease based on features such as the heart rate during exercise or blood cholesterol levels.

Blood and heart

Blood is very important to ensure the proper functioning of the body. Its functions cover the transport of oxygen and nutrients to the cells of the body as well as the removal of the cellular waste products.

Blood reaches the rest of the body because it is pumped by the heart. This organ receives oxygen-poor blood and sends it to the lungs to be oxygenated, and then sends the oxygen-rich blood coming back from the lungs to the rest of the body.

Blood flow through the chambers of the heart

By josiño - Own work, Public Domain, https://commons.wikimedia.org/w/index.php?curid=9396374. Flow of the blood through the chambers of the heart. Blue arrows represent oxygen-poor blood received from the rest of the body and sent to the lungs. Red arrows represent oxygen-rich blood coming from the lungs that is sent to the rest of the body.

An inadequate blood supply can leave cells without enough energy to function properly, causing cell death in the worst case.

Coronary heart disease

The heart itself also needs oxygen and nutrients to function properly; these arrive through arteries known as coronary arteries. When we talk about coronary disease, we usually mean restricted blood flow in these arteries due to the accumulation of substances on their walls.

Death of heart cells due to an ischemia in the coronary arteries

By NIH: National Heart, Lung and Blood Institute - http://www.nhlbi.nih.gov/health/health-topics/topics/heartattack/, Public Domain, https://commons.wikimedia.org/w/index.php?curid=25287085. Death of heart cells due to an ischemia in the coronary arteries.

In the worst case, depriving the heart cells of nutrients and oxygen leads to a heart attack, that is, the death of part of the heart cells. This, in turn, affects the rest of the body because the heart's pumping capacity is reduced.

Glossary of terms

  • Atherosclerosis: accumulation of substances on the walls of arteries, which can hinder blood flow. Moreover, the rupture of this plaque can cause the formation of a blood clot (thrombus) that, in turn, can further block the affected area or travel to other parts of the body and block vessels there (embolism). (Sources: American Heart Association, Mayo Clinic)

  • Ischemia: blood flow reduction to a tissue. This implies a reduction of the supply of oxygen and nutrients, so cells won’t get enough energy to function properly. (Sources: American Heart Association, Mayo Clinic, Wikipedia)

  • Angina: chest pain due to a blood flow reduction in the coronary arteries. (Sources: United Kingdom National Health Service, (Spanish) Video sobre angina de Alberto Sanagustín)

  • Stable angina: angina triggered by situations that increase the heart's oxygen demand (for example, exercise or stress) and that goes away with rest.

  • Unstable angina: angina that can occur even at rest.

  • Typical & atypical angina: typical angina usually presents as chest discomfort. However, some people can experience other symptoms, such as nausea or shortness of breath; in these cases we talk about atypical angina. (Sources: Harrington Hospital, Wikipedia).

  • Thrombus: blood mass in solid state that hinders the blood flow in a blood vessel. (Source: MedlinePlus)

  • Embolus: a thrombus that detaches and travels to other parts of the body. (Source: MedlinePlus)

  • Acute myocardial infarction: also known as a heart attack, the death of part of the heart tissue due to ischemia; in other words, the death of part of the cells due to lack of oxygen. (Sources: Healthline, Wikipedia)

  • Electrocardiogram: graphical record of the electrical signals that cause heartbeats. Each part of the record of a normal heartbeat has a name; the most interesting ones for this project are the T wave and the ST segment, because they can give some information about the presence of issues like ischemia or infarction. (Sources: Mayo Clinic, Wikipedia, (Spanish) Video sobre electrocardiograma de Alberto Sanagustín, (Spanish) Serie de videos sobre el electrocardiograma normal de Alberto Sanagustín)

  • Nuclear stress test: a test in which a radioactive dye is injected into the patient to visualize blood flow both at rest and during exercise. During this test, the activity of the heart is also measured with an electrocardiogram. (Sources: Mayo Clinic, Healthline)

  • Asymptomatic disease: a disease present in a patient who experiences few or no symptoms. (Sources: (Spanish) definicion.de, Mayo Clinic, Wikipedia)

  • Left ventricular hypertrophy: thickening of the walls of the main heart chamber that pumps blood to the rest of the body. This can cause the muscle to lose elasticity, which, in turn, prevents the heart from working properly. (Source: Mayo Clinic)

0.- Libraries required

This is the list of libraries required for this hands-on:

library(caret)
library(rpart.plot)
library(tidyverse)
library(dplyr)
library(knitr)
library(ggpubr)
library(skimr)
library(ggplot2)
library(gridExtra)
library(pheatmap)
library(rsample)
library(recipes)
library(GGally)
library(visdat)
library(glmnet)
library(precrec)
library(kableExtra)
library(patchwork)

1.- The Data set and Exploratory data analysis (EDA)

setwd("~/Library/Mobile Documents/com~apple~CloudDocs/Master/Machine learning/sesion1")

data <- read.csv('heart_mod.csv')

# save the theme so it can be reused in all figures
MY_THEME <- theme(
        text = element_text(family = "Roboto"),
        axis.text.x = element_text(angle = 35, vjust = .6),
        axis.title.x = element_blank(),
        axis.ticks = element_blank(),
        axis.line = element_line(colour = "grey50"),
        panel.grid = element_line(color = "#b4aea9"),
        panel.grid.minor = element_blank(),
        panel.grid.major.x = element_blank(),
        panel.grid.major.y = element_line(linetype = "dashed"),
        panel.background = element_rect(fill = "#fbf9f4", color = "#fbf9f4"),
        plot.background = element_rect(fill = "#fbf9f4", color = "#fbf9f4"),
        legend.background = element_rect(fill = "#fbf9f4"),
        plot.title = element_text(
            family = "Roboto",
            size = 16,
            face = "bold",
            color = "#2a475e",
            margin = margin(b = 20)
        )
    )

The main goal of this step is to achieve a better understanding of what information each variable contains, as well as detecting possible errors. Some common examples are:

  • That a column has been stored with the wrong type: a numeric variable is being recognized as text or vice versa.

  • That a variable contains values that do not make sense.
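As a quick sketch of these checks (the thresholds below are illustrative assumptions, not clinical cutoffs), we could inspect the column types and look for impossible values:

```r
library(dplyr)

# Storage type of each column
sapply(data, class)

# Rows with values that make no physical sense
# (non-positive blood pressure or cholesterol, sex outside {0, 1})
data %>% filter(trestbps <= 0 | chol <= 0 | !(sex %in% c(0, 1)))
```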

1.1 Variable type

There are different functions in R that help us summarize the types of the variables we have, such as glimpse, summary, or skim.

glimpse(data)
## Rows: 303
## Columns: 15
## $ X        <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18…
## $ age      <int> 63, 37, 41, 56, 57, 57, 56, 44, 52, 57, 54, 48, 49, 64, 58, 5…
## $ sex      <int> 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1…
## $ cp       <int> 3, 2, 1, 1, 0, 0, 1, 1, 2, 2, 0, 2, 1, 3, 3, 2, 2, 3, 0, 3, 0…
## $ trestbps <int> 145, 130, 130, 120, 120, 140, 140, 120, 172, 150, 140, 130, 1…
## $ chol     <int> 233, 250, 204, 236, 354, 192, 294, 263, 199, 168, 239, 275, 2…
## $ fbs      <int> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0…
## $ restecg  <int> 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1…
## $ thalach  <int> 150, 187, 172, 178, 163, 148, 153, 173, 162, 174, 160, 139, 1…
## $ exang    <int> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0…
## $ oldpeak  <dbl> 2.3, 3.5, 1.4, 0.8, 0.6, 0.4, 1.3, 0.0, 0.5, 1.6, 1.2, 0.2, 0…
## $ slope    <int> 0, 0, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 0, 2, 2, 1…
## $ ca       <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0…
## $ thal     <int> 1, 2, 2, 2, 2, 1, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3…
## $ target   <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
summary(data)
##        X              age             sex               cp       
##  Min.   :  1.0   Min.   :29.00   Min.   :0.0000   Min.   :0.000  
##  1st Qu.: 76.5   1st Qu.:47.50   1st Qu.:0.0000   1st Qu.:0.000  
##  Median :152.0   Median :55.00   Median :1.0000   Median :1.000  
##  Mean   :152.0   Mean   :54.37   Mean   :0.6832   Mean   :0.967  
##  3rd Qu.:227.5   3rd Qu.:61.00   3rd Qu.:1.0000   3rd Qu.:2.000  
##  Max.   :303.0   Max.   :77.00   Max.   :1.0000   Max.   :3.000  
##                                                                  
##     trestbps          chol            fbs            restecg      
##  Min.   : 94.0   Min.   :126.0   Min.   :0.0000   Min.   :0.0000  
##  1st Qu.:120.0   1st Qu.:211.0   1st Qu.:0.0000   1st Qu.:0.0000  
##  Median :130.0   Median :240.0   Median :0.0000   Median :1.0000  
##  Mean   :131.6   Mean   :246.3   Mean   :0.1485   Mean   :0.5281  
##  3rd Qu.:140.0   3rd Qu.:274.5   3rd Qu.:0.0000   3rd Qu.:1.0000  
##  Max.   :200.0   Max.   :564.0   Max.   :1.0000   Max.   :2.0000  
##                                                                   
##     thalach          exang           oldpeak         slope      
##  Min.   : 71.0   Min.   :0.0000   Min.   :0.00   Min.   :0.000  
##  1st Qu.:133.5   1st Qu.:0.0000   1st Qu.:0.00   1st Qu.:1.000  
##  Median :153.0   Median :0.0000   Median :0.80   Median :1.000  
##  Mean   :149.6   Mean   :0.3267   Mean   :1.04   Mean   :1.399  
##  3rd Qu.:166.0   3rd Qu.:1.0000   3rd Qu.:1.60   3rd Qu.:2.000  
##  Max.   :202.0   Max.   :1.0000   Max.   :6.20   Max.   :2.000  
##                                                                 
##        ca              thal           target      
##  Min.   :0.0000   Min.   :1.000   Min.   :0.0000  
##  1st Qu.:0.0000   1st Qu.:2.000   1st Qu.:0.0000  
##  Median :0.0000   Median :2.000   Median :1.0000  
##  Mean   :0.6745   Mean   :2.329   Mean   :0.5446  
##  3rd Qu.:1.0000   3rd Qu.:3.000   3rd Qu.:1.0000  
##  Max.   :3.0000   Max.   :3.000   Max.   :1.0000  
##  NA's   :5        NA's   :2
skim(data)
Data summary
Name data
Number of rows 303
Number of columns 15
_______________________
Column type frequency:
numeric 15
________________________
Group variables None

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
X 0 1.00 152.00 87.61 1 76.5 152.0 227.5 303.0 ▇▇▇▇▇
age 0 1.00 54.37 9.08 29 47.5 55.0 61.0 77.0 ▁▆▇▇▁
sex 0 1.00 0.68 0.47 0 0.0 1.0 1.0 1.0 ▃▁▁▁▇
cp 0 1.00 0.97 1.03 0 0.0 1.0 2.0 3.0 ▇▃▁▅▁
trestbps 0 1.00 131.62 17.54 94 120.0 130.0 140.0 200.0 ▃▇▅▁▁
chol 0 1.00 246.26 51.83 126 211.0 240.0 274.5 564.0 ▃▇▂▁▁
fbs 0 1.00 0.15 0.36 0 0.0 0.0 0.0 1.0 ▇▁▁▁▂
restecg 0 1.00 0.53 0.53 0 0.0 1.0 1.0 2.0 ▇▁▇▁▁
thalach 0 1.00 149.65 22.91 71 133.5 153.0 166.0 202.0 ▁▂▅▇▂
exang 0 1.00 0.33 0.47 0 0.0 0.0 1.0 1.0 ▇▁▁▁▃
oldpeak 0 1.00 1.04 1.16 0 0.0 0.8 1.6 6.2 ▇▂▁▁▁
slope 0 1.00 1.40 0.62 0 1.0 1.0 2.0 2.0 ▁▁▇▁▇
ca 5 0.98 0.67 0.94 0 0.0 0.0 1.0 3.0 ▇▃▁▂▁
thal 2 0.99 2.33 0.58 1 2.0 2.0 3.0 3.0 ▁▁▇▁▆
target 0 1.00 0.54 0.50 0 0.0 1.0 1.0 1.0 ▇▁▁▁▇

1.2 Dataset features

This dataset is hosted on Kaggle (Heart Disease UCI) and originally comes from the UCI Machine Learning Repository. It contains records of 303 patients from Cleveland; the features are described in the next section.

Attribute Information:

  1. age
  2. sex
  3. chest pain type (4 values)
  4. resting blood pressure
  5. serum cholesterol in mg/dl
  6. fbs: fasting blood sugar > 120 mg/dl

From here on, the variables relate to a nuclear stress test, that is, a stress test in which a radioactive dye is also injected into the patient to visualize blood flow:

  7. restecg: resting electrocardiographic results (values 0, 1, 2)
  8. thalach: maximum heart rate achieved
  9. exang: exercise-induced angina
  10. oldpeak: ST depression induced by exercise relative to rest
  11. slope: the slope of the peak exercise ST segment
  12. ca: number of major vessels (0-3) colored by fluoroscopy
  13. thal: 3 = normal; 6 = fixed defect; 7 = reversible defect
  14. target: 0 = yes; 1 = no

TASK TO DO:

  1. Remove X column.
  2. Transform categorical variable to R factors.
  3. Give (if necessary) a better name to the factor values (it will be helpful for the graphs).
# remove X column
try(data <- data %>% dplyr::select(-X))
# Transform categorical variable to R factors

data <- data %>%
    mutate(across(c(
        sex, cp, fbs, restecg, exang, slope, thal, target, ca
    ), as.factor))
# Give a better name to the factor values for the graphs
levels(data$sex) <- c("Female", "Male")
levels(data$cp) <- c("Asymptomatic", "Atypical angina", "No angina", "Typical angina")
levels(data$fbs) <- c("No", "Yes")
levels(data$restecg) <- c("Hypertrophy", "Normal", "Abnormalities")
levels(data$exang) <- c("No", "Yes")
levels(data$slope) <- c("Descending", "Flat", "Ascending")
levels(data$thal) <- c("Fixed defect", "Normal flow", "Reversible defect")
levels(data$target) <- c("Yes", "No")

Next step: inspect all variables and form hypotheses about how each variable affects heart attack incidence (target column).

1.2.1 Inspect variables: target

Target variable: whether the patient has a heart disease or not

  • Value 0: yes
  • Value 1: no

We can see that the class distribution is quite balanced. Thanks to this, it would not be a bad idea to use accuracy to evaluate how well the models perform.

ggplot(data, aes(target, fill=target)) + 
  geom_bar(width = .6) +
  labs(x="Disease", y="Number of patients") +
  guides(fill="none") + MY_THEME
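To put a number on that balance, we can tabulate the class proportions (a small sketch using base R):

```r
# Proportion of patients in each target class
round(prop.table(table(data$target)), 3)
```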

1.2.2 Inspect variables: age vs. target

Visualize how age affects the chances of having or not having a heart attack.

Options: density plot or boxplot. Check on Google how to do it with ggplot.

data %>%
    ggplot(aes(x = target, y = age)) +
    geom_violin(width = .6) +
    geom_boxplot(width = .15) +
    geom_point(
        aes(col = target),
        position = position_jitterdodge(jitter.width = .3),
        alpha = .2,
        size = 3
    ) +
    theme_pubclean() +
    MY_THEME

1.2.3 Inspect variables: sex vs. target

Patient sex

  • Value 0: female
  • Value 1: male

Options: bar charts with the number of cases, bar charts with the proportion of cases in each group, and bar charts separated by class.

data %>%
    ggplot(aes(x = sex, fill = target)) +
    geom_bar(width = .5) +
    MY_THEME

1.2.x Inspect variables: variable x vs. target

Complete with the rest of the variables.

Since in this case we only have 13 variables, it seems reasonable to go one by one.

There is also the ggpairs function to see how the variables are related.

In the event that we have many variables, this would not be feasible. To solve this, a useful tool is the heatmap for continuous variables.

numeric_variables <- which(unlist(lapply(data, is.numeric)))
pheatmap(cor(data[, numeric_variables[-6]]))

ggpairs(
    data = data,
    columns = 1:4,
    aes(color = target),
    legend = 1
) +
    theme_bw()

vars <- data %>% colnames()

my.plot <- function(var){
    if (class(data[, var]) == "factor") {
        p <- data %>%
            ggplot(aes_string(x = var, fill = "target")) +
            geom_bar(width = .5, alpha = .9) +
            ggtitle(paste0(var, " vs target"))
    } else {
        p <- data %>%
            ggplot(aes_string(x = "target", y = var)) +
            geom_violin(width = .6) +
            geom_boxplot(width = .15) +
            geom_point(
                aes_string(col = "target"),
                position = position_jitterdodge(jitter.width = .3),
                alpha = .2,
                size = .5
            ) +
            ggtitle(paste0(var, " vs target")) +
            guides(col = "none")
    }
    return(p)
}

p <- lapply(vars, my.plot)

p %>% 
    wrap_plots(ncol = 3, guides = "collect") +
    guide_area() &
    MY_THEME

Hypotheses drawn from this last figure:

  • age: the concentration of individuals with heart attacks is noticeably higher from age 50 onward, so age is a risk factor.
  • sex: while a relatively small percentage of women have suffered a heart attack, more than half of the men have, so being male is also a risk factor.
  • cp: regarding chest pain, in most cases the patients were asymptomatic, so it does not look like a significant indicator.
  • trestbps: there is a slight skew toward high values in the heart-attack group, so high values may be indicative, but in most cases the resting blood pressure was in the normal range.
  • chol: there does not seem to be a difference in cholesterol values between the two groups.
  • fbs: most individuals had values below 120 mg/dL (both with and without a heart attack), and the ratio looks similar in both cases (~50/50).

target: 0 = yes; 1 = No

data %>%
    dplyr::select(c(target, restecg)) %>%
    table() %>%
    kableExtra::kable()
Hypertrophy Normal Abnormalities
Yes 79 56 3
No 68 96 1
  • restecg: to assess this chart properly, I decided to look at the table of proportions, since the number of cases with Abnormalities was very low. From it we see that hypertrophy increases the risk of heart attack compared with a normal reading, and that with abnormalities the risk shoots up.
  • thalach: individuals without a heart attack reach significantly higher heart rate values.
  • exang: individuals with exercise-induced angina are much more likely to have suffered a heart attack.
  • oldpeak: individuals with high ST depression values are much more likely to have suffered a heart attack.
  • slope: individuals with a descending or flat slope are at much higher risk than those with an ascending slope.
  • ca: the more vessels colored by fluoroscopy, the higher the risk.
  • thal: having defects, whether fixed or reversible, severely increases the risk.
  • target: the class distribution is fairly balanced, which makes it suitable for training a classification model.

2. Data splitting

Given a fixed amount of data, typical recommendations for splitting your data into training-test splits include 60% (training)–40% (testing), 70%–30%, or 80%–20%. Generally speaking, these are appropriate guidelines to follow; however, it is good to keep the following points in mind:

  • Spending too much on training (e.g., >80%) won’t allow us to get a good assessment of predictive performance. We may find a model that fits the training data very well but does not generalize (overfitting).

  • Spending too much on testing (>40%) won’t allow us to get a good assessment of model parameters.

2.1 Random Sampling

Using the library rsample and its corresponding functions:

  • initial_split
  • training
  • testing

Remember to use the function set.seed in order to replicate the results.

set.seed(123) #important in order to replicate

split_basico <- initial_split(data, prop = .7)
sb_train <- training(split_basico)
sb_test <- testing(split_basico)

p1 <- sb_train %>%
    ggplot(aes(x = "target", fill = target)) +
    geom_bar(position = "fill", width = .4) +
    MY_THEME

p2 <- sb_test %>%
    ggplot(aes(x = "target", fill = target)) +
    geom_bar(position = "fill", width = .4) +
    MY_THEME

wrap_plots(list(p1, p2), guides = "collect")  # a slight class imbalance is visible

kableExtra::kable(table(split_basico$data$target) %>% prop.table())
Var1 Freq
Yes 0.4554455
No 0.5445545
kableExtra::kable(table(sb_train$target) %>% prop.table())
Var1 Freq
Yes 0.4433962
No 0.5566038
kableExtra::kable(table(sb_test$target) %>% prop.table())
Var1 Freq
Yes 0.4835165
No 0.5164835

The class balance of the target is not constant across the splits. Let's stratify to guarantee that this balance is preserved in both train and test.

2.2 Stratified Sampling

If we want to explicitly control the sampling so that our training and test sets have similar Y distributions, we can use stratified sampling.

This is more common with classification problems where the response variable may be severely imbalanced (e.g., 90% of observations with response “Yes” and 10% with response “No”).

Check the help of the function initial_split to see how to do it.

We can use the functions table and prop.table to check if training and test sets have similar Y distributions.

set.seed(123) #important in order to replicate

split_strat <- initial_split(data, prop = .7, strata = "target")
strat_train <- training(split_strat)
strat_test <- testing(split_strat)

p1 <- strat_train %>%
    ggplot(aes(x = "target", fill = target)) +
    geom_bar(position = "fill", width = .4) +
    MY_THEME

p2 <- strat_test %>%
    ggplot(aes(x = "target", fill = target)) +
    geom_bar(position = "fill", width = .4) +
    MY_THEME

wrap_plots(list(p1, p2), guides = "collect")  # balanced classes

kableExtra::kable(table(split_strat$data$target) %>% prop.table())
Var1 Freq
Yes 0.4554455
No 0.5445545
kableExtra::kable(table(strat_train$target) %>% prop.table())
Var1 Freq
Yes 0.4549763
No 0.5450237
kableExtra::kable(table(strat_test$target) %>% prop.table())
Var1 Freq
Yes 0.4565217
No 0.5434783

Now the Yes/No proportion is indeed consistent across the full dataset, the training set, and the test set.

3. Feature and target engineering

3.1 Imputation of missing values

We can use the function vis_miss from the visdat library, which provides an at-a-glance ggplot of the missingness inside a dataframe.

sum(is.na(data))
## [1] 7
vis_miss(data, cluster = TRUE)

Then, to impute missing values we can use the recipes R package.

It has the following steps:

  1. Function recipe: A recipe is a description of the steps to be applied to a data set in order to prepare it for data analysis.

  2. step_impute_xxxx: creates a specification of a recipe step that will substitute missing values: step_impute_mean creates a specification of a recipe step that will substitute missing values of numeric variables by the training set mean of those variables. step_impute_knn creates a specification of a recipe step that will impute missing data using nearest neighbors. Can be applied to both numeric and categorical variables.

We prepare a recipe to impute the NAs in all the predictors.

my_recipe <- recipe(target ~ ., data = strat_train) %>% 
    step_impute_knn(all_predictors())
  3. prep: For a recipe with at least one preprocessing operation, estimate the required parameters from a training set so that they can later be applied to other data sets.

We train the KNN models that will be used for the imputation, using the training data.

trained_recipe <- prep(my_recipe, training = strat_train)
  4. bake: For a recipe with at least one preprocessing operation that has been trained by prep, apply the computations to new data.

Finally, we impute the NAs with the trained models.

datos_train_prep <- bake(trained_recipe, new_data = strat_train)
datos_test_prep <- bake(trained_recipe, new_data = strat_test)
wrap_plots(list(vis_miss(datos_train_prep, cluster = T), vis_miss(datos_test_prep)), guides = "collect", nrow = 2)

Now we can see that there are no NAs left.
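In addition to the plots, a quick numeric check confirms the imputation:

```r
# Count of remaining NAs; both should be zero after imputation
sum(is.na(datos_train_prep))
sum(is.na(datos_test_prep))
```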

4. Creation of the model

4.1 Logistic regression using glmnet

We are going to use the function glmnet. From its help page:

“Fit a generalized linear model via penalized maximum likelihood. The regularization path is computed for the lasso or elasticnet penalty at a grid of values for the regularization parameter lambda. Can deal with all shapes of data, including very large sparse data matrices. Fits linear, logistic and multinomial, poisson, and Cox regression models.”

4.1.1 logistic regression (\(\lambda = 0\))

set.seed(123)

glmnet_traindata <- data.matrix(datos_train_prep[, -14])

fit_logistic_regression <-
    glmnet(
        x = glmnet_traindata,
        # it must be a matrix: use data.matrix
        y = datos_train_prep$target,
        family = "binomial",
        lambda = 0,
        intercept = TRUE
    )
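With \(\lambda = 0\) this is ordinary (unpenalized) logistic regression, so the coefficients can be read as log-odds contributions. A small sketch to inspect them:

```r
# Coefficients of the unpenalized logistic model (sparse matrix)
coef(fit_logistic_regression)

# Exponentiate to interpret each coefficient as an odds ratio
exp(as.matrix(coef(fit_logistic_regression)))
```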

4.1.2 logistic regression with penalty: lasso

set.seed(123)
fit_LogReg_cv_lasso <- cv.glmnet(x = glmnet_traindata,
                                 y = datos_train_prep$target,
                                 family = "binomial",
                                 nfold=10,
                                 alpha=1,
                                 type.measure = "auc")
plot(fit_LogReg_cv_lasso)
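The cross-validation object stores the value of lambda that maximized the AUC; a sketch of how to extract it together with the surviving coefficients (the lasso sets some of them exactly to zero):

```r
# Lambda with the best cross-validated AUC
fit_LogReg_cv_lasso$lambda.min

# Coefficients at that lambda; zero entries are variables dropped by the lasso
coef(fit_LogReg_cv_lasso, s = "lambda.min")
```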

4.1.3 logistic regression with penalty: Ridge

fit_LogReg_cv_ridge <- cv.glmnet(x = glmnet_traindata,
                                 y = datos_train_prep$target,
                                 family = "binomial",
                                 nfold=10,
                                 alpha=0,
                                 type.measure = "auc")
plot(fit_LogReg_cv_ridge)

4.1.4 logistic regression with penalty: elasticnet

fit_LogReg_cv_en <- cv.glmnet(x = glmnet_traindata,
                                 y = datos_train_prep$target,
                                 family = "binomial",
                                 nfold=10,
                                 alpha=0.2,
                                 type.measure = "auc")
plot(fit_LogReg_cv_en)
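To compare the three penalized fits, we can look at the best cross-validated AUC of each one (a sketch; since type.measure = "auc", the cvm component holds the AUC values):

```r
# Best cross-validated AUC for each penalty type
data.frame(
    model = c("lasso", "ridge", "elasticnet"),
    best_auc = c(
        max(fit_LogReg_cv_lasso$cvm),
        max(fit_LogReg_cv_ridge$cvm),
        max(fit_LogReg_cv_en$cvm)
    )
)
```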

4.2 Logistic regression using caret.

4.2.1 By default caret

Before training the model, we need to apply the function trainControl to specify some training parameters.

control_train <-
    trainControl(
        method = "cv",
        # resampling method: e.g., "cv", "boot", "repeatedcv", "none", etc.
        number = 10,
        # number of folds or number of resampling iterations.
        returnResamp = "all",
        classProbs = TRUE,
        search = "grid",
        savePredictions = TRUE
    )

Once we have established the training characteristics, with the function train we train our model. This function “sets up a grid of tuning parameters for a number of classification and regression routines, fits each model and calculates a resampling based performance measure.”

set.seed(123)
modelo_glm_caret <- train(
    target ~ .,
    method = "glmnet",
    family = "binomial",
    trControl = control_train,
    data = datos_train_prep,
    # train data.
    metric = "Accuracy"
)

modelo_glm_caret
## glmnet 
## 
## 211 samples
##  13 predictor
##   2 classes: 'Yes', 'No' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 190, 189, 191, 191, 189, 190, ... 
## Resampling results across tuning parameters:
## 
##   alpha  lambda        Accuracy   Kappa    
##   0.10   0.0005010359  0.8208009  0.6329692
##   0.10   0.0050103594  0.8103463  0.6107406
##   0.10   0.0501035939  0.8246537  0.6419031
##   0.55   0.0005010359  0.8158009  0.6214351
##   0.55   0.0050103594  0.8151082  0.6200658
##   0.55   0.0501035939  0.7878788  0.5685067
##   1.00   0.0005010359  0.8158009  0.6214351
##   1.00   0.0050103594  0.8151082  0.6200658
##   1.00   0.0501035939  0.7917316  0.5756382
## 
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0.1 and lambda = 0.05010359.

Let’s see some of the information we have in modelo_glm_caret:

plot(modelo_glm_caret) 

best combination of alpha and lambda:

modelo_glm_caret$bestTune
##   alpha     lambda
## 3   0.1 0.05010359

The result of each fold:

performance <- modelo_glm_caret$resample

kbl(performance) %>%
  kable_paper() %>%
  scroll_box(width = "100%", height = "200px")
alpha lambda Accuracy Kappa Resample
0.10 0.0501036 0.9047619 0.8000000 Fold01
0.10 0.0050104 0.9047619 0.8000000 Fold01
0.10 0.0005010 0.9047619 0.8000000 Fold01
0.55 0.0501036 0.8571429 0.6956522 Fold01
0.55 0.0050104 0.9047619 0.8000000 Fold01
0.55 0.0005010 0.9047619 0.8000000 Fold01
1.00 0.0501036 0.9047619 0.8000000 Fold01
1.00 0.0050104 0.9047619 0.8000000 Fold01
1.00 0.0005010 0.9047619 0.8000000 Fold01
0.10 0.0501036 0.7727273 0.5378151 Fold02
0.10 0.0050104 0.6818182 0.3529412 Fold02
0.10 0.0005010 0.6818182 0.3529412 Fold02
0.55 0.0501036 0.7272727 0.4500000 Fold02
0.55 0.0050104 0.6818182 0.3529412 Fold02
0.55 0.0005010 0.6818182 0.3529412 Fold02
1.00 0.0501036 0.7272727 0.4500000 Fold02
1.00 0.0050104 0.6818182 0.3529412 Fold02
1.00 0.0005010 0.6818182 0.3529412 Fold02
0.10 0.0501036 1.0000000 1.0000000 Fold03
0.10 0.0050104 1.0000000 1.0000000 Fold03
0.10 0.0005010 1.0000000 1.0000000 Fold03
0.55 0.0501036 1.0000000 1.0000000 Fold03
0.55 0.0050104 1.0000000 1.0000000 Fold03
0.55 0.0005010 1.0000000 1.0000000 Fold03
1.00 0.0501036 1.0000000 1.0000000 Fold03
1.00 0.0050104 1.0000000 1.0000000 Fold03
1.00 0.0005010 1.0000000 1.0000000 Fold03
0.10 0.0501036 0.6000000 0.1578947 Fold04
0.10 0.0050104 0.5500000 0.0425532 Fold04
0.10 0.0005010 0.6000000 0.1578947 Fold04
0.55 0.0501036 0.6500000 0.2553191 Fold04
0.55 0.0050104 0.5500000 0.0425532 Fold04
0.55 0.0005010 0.5500000 0.0425532 Fold04
1.00 0.0501036 0.6500000 0.2553191 Fold04
1.00 0.0050104 0.5500000 0.0425532 Fold04
1.00 0.0005010 0.5500000 0.0425532 Fold04
0.10 0.0501036 0.7727273 0.5378151 Fold05
0.10 0.0050104 0.7727273 0.5378151 Fold05
0.10 0.0005010 0.7272727 0.4406780 Fold05
0.55 0.0501036 0.7272727 0.4500000 Fold05
0.55 0.0050104 0.7727273 0.5378151 Fold05
0.55 0.0005010 0.7272727 0.4406780 Fold05
1.00 0.0501036 0.7727273 0.5378151 Fold05
1.00 0.0050104 0.7727273 0.5378151 Fold05
1.00 0.0005010 0.7272727 0.4406780 Fold05
0.10 0.0501036 0.7142857 0.4220183 Fold06
0.10 0.0050104 0.6666667 0.3287671 Fold06
0.10 0.0005010 0.6666667 0.3287671 Fold06
0.55 0.0501036 0.6666667 0.3225806 Fold06
0.55 0.0050104 0.7142857 0.4220183 Fold06
0.55 0.0005010 0.6666667 0.3287671 Fold06
1.00 0.0501036 0.6666667 0.3225806 Fold06
1.00 0.0050104 0.7142857 0.4220183 Fold06
1.00 0.0005010 0.6666667 0.3287671 Fold06
0.10 0.0501036 0.8500000 0.7000000 Fold07
0.10 0.0050104 0.8500000 0.6938776 Fold07
0.10 0.0005010 0.9500000 0.8979592 Fold07
0.55 0.0501036 0.8500000 0.7000000 Fold07
0.55 0.0050104 0.8500000 0.6938776 Fold07
0.55 0.0005010 0.9500000 0.8979592 Fold07
1.00 0.0501036 0.7500000 0.4897959 Fold07
1.00 0.0050104 0.8500000 0.6938776 Fold07
1.00 0.0005010 0.9500000 0.8979592 Fold07
0.10 0.0501036 0.8636364 0.7226891 Fold08
0.10 0.0050104 0.8636364 0.7226891 Fold08
0.10 0.0005010 0.8636364 0.7226891 Fold08
0.55 0.0501036 0.8181818 0.6333333 Fold08
0.55 0.0050104 0.8636364 0.7226891 Fold08
0.55 0.0005010 0.8636364 0.7226891 Fold08
1.00 0.0501036 0.8636364 0.7226891 Fold08
1.00 0.0050104 0.8636364 0.7226891 Fold08
1.00 0.0005010 0.8636364 0.7226891 Fold08
0.10 0.0501036 0.9047619 0.8090909 Fold09
0.10 0.0050104 0.9047619 0.8090909 Fold09
0.10 0.0005010 0.9047619 0.8090909 Fold09
0.55 0.0501036 0.8095238 0.6181818 Fold09
0.55 0.0050104 0.9047619 0.8090909 Fold09
0.55 0.0005010 0.9047619 0.8090909 Fold09
1.00 0.0501036 0.8095238 0.6181818 Fold09
1.00 0.0050104 0.9047619 0.8090909 Fold09
1.00 0.0005010 0.9047619 0.8090909 Fold09
0.10 0.0501036 0.8636364 0.7317073 Fold10
0.10 0.0050104 0.9090909 0.8196721 Fold10
0.10 0.0005010 0.9090909 0.8196721 Fold10
0.55 0.0501036 0.7727273 0.5600000 Fold10
0.55 0.0050104 0.9090909 0.8196721 Fold10
0.55 0.0005010 0.9090909 0.8196721 Fold10
1.00 0.0501036 0.7727273 0.5600000 Fold10
1.00 0.0050104 0.9090909 0.8196721 Fold10
1.00 0.0005010 0.9090909 0.8196721 Fold10

We can plot the results:

performance$alpha <- as.factor(performance$alpha)
performance$lambda <- as.factor(performance$lambda)
ggplot(data = performance, aes(x = alpha, y = Accuracy,color=lambda)) +
  geom_boxplot() +
  geom_point(position=position_jitterdodge())+
  labs(x = "") +
  theme_bw() 
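Besides the boxplot, a compact numeric summary of the folds can make the comparison easier. A minimal sketch with dplyr (assumed to be available) on the same performance data frame:

```r
library(dplyr)

# Mean and standard deviation of the fold accuracies for each
# hyperparameter combination, best combinations first.
performance %>%
    group_by(alpha, lambda) %>%
    summarise(mean_acc = mean(Accuracy),
              sd_acc   = sd(Accuracy),
              .groups  = "drop") %>%
    arrange(desc(mean_acc))
```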

4.2.2 Tuning with caret

control_train <-
    trainControl(
        method = "cv",
        # resampling method: "boot", "cv", "LOOCV", "none", etc.
        number = 10,
        # number of folds or number of resampling iterations
        returnResamp = "all",
        classProbs = TRUE,
        summaryFunction = twoClassSummary,
        # function used to compute performance metrics across resamples
        search = "grid",
        savePredictions = TRUE
    )

How can we run a custom hyperparameter search? By building a grid with expand.grid and passing it to train through its tuneGrid parameter:

lambda <- c(0,0.01, 0.1)
alpha <- c(0,0.1,0.3,0.5, 0.9, 1)
hyper_grid <- expand.grid(alpha = alpha, lambda = lambda)

set.seed(123)
modelo_glm_caret_grid <- train(
    target ~ .,
    method = "glmnet",
    family = "binomial",
    trControl = control_train,
    data = datos_train_prep,
    tuneGrid = hyper_grid,
    metric = "ROC"
)

modelo_glm_caret_grid
## glmnet 
## 
## 211 samples
##  13 predictor
##   2 classes: 'Yes', 'No' 
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 190, 189, 191, 191, 189, 190, ... 
## Resampling results across tuning parameters:
## 
##   alpha  lambda  ROC        Sens       Spec     
##   0.0    0.00    0.8829461  0.7588889  0.8780303
##   0.0    0.01    0.8829461  0.7588889  0.8780303
##   0.0    0.10    0.8871886  0.7700000  0.8613636
##   0.1    0.00    0.8674411  0.7500000  0.8787879
##   0.1    0.01    0.8737542  0.7377778  0.8787879
##   0.1    0.10    0.8871801  0.7600000  0.8530303
##   0.3    0.00    0.8666077  0.7500000  0.8787879
##   0.3    0.01    0.8754966  0.7377778  0.8787879
##   0.3    0.10    0.8774579  0.7500000  0.8363636
##   0.5    0.00    0.8674411  0.7500000  0.8787879
##   0.5    0.01    0.8764057  0.7377778  0.8696970
##   0.5    0.10    0.8680640  0.7277778  0.8363636
##   0.9    0.00    0.8674411  0.7500000  0.8787879
##   0.9    0.01    0.8727189  0.7477778  0.8613636
##   0.9    0.10    0.8305724  0.7066667  0.8000000
##   1.0    0.00    0.8674411  0.7500000  0.8787879
##   1.0    0.01    0.8736448  0.7477778  0.8613636
##   1.0    0.10    0.8277189  0.6855556  0.8000000
## 
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were alpha = 0 and lambda = 0.1.
plot(modelo_glm_caret_grid)

modelo_glm_caret_grid$bestTune
##   alpha lambda
## 3     0    0.1
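It can also be useful to inspect the coefficients of the final model that caret refit on the whole training set, evaluated at the selected lambda (a sketch using the objects above):

```r
# Coefficients of the refitted glmnet model at the lambda chosen by
# cross-validation; with alpha = 0 (ridge) none of them is exactly zero.
coef(modelo_glm_caret_grid$finalModel,
     s = modelo_glm_caret_grid$bestTune$lambda)
```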
performance <- modelo_glm_caret_grid$resample
kbl(performance) %>%
    kable_paper() %>%
    scroll_box(width = "100%", height = "200px")
alpha lambda ROC Sens Spec Resample
0.0 0.10 0.9537037 0.7777778 1.0000000 Fold01
0.0 0.00 0.9537037 0.7777778 1.0000000 Fold01
0.0 0.01 0.9537037 0.7777778 1.0000000 Fold01
0.1 0.10 0.9629630 0.7777778 1.0000000 Fold01
0.1 0.00 0.9537037 0.7777778 1.0000000 Fold01
0.1 0.01 0.9537037 0.7777778 1.0000000 Fold01
0.3 0.10 0.9629630 0.8888889 1.0000000 Fold01
0.3 0.00 0.9537037 0.7777778 1.0000000 Fold01
0.3 0.01 0.9537037 0.7777778 1.0000000 Fold01
0.5 0.10 0.9629630 0.7777778 1.0000000 Fold01
0.5 0.00 0.9537037 0.7777778 1.0000000 Fold01
0.5 0.01 0.9537037 0.7777778 1.0000000 Fold01
0.9 0.10 0.9537037 0.6666667 1.0000000 Fold01
0.9 0.00 0.9537037 0.7777778 1.0000000 Fold01
0.9 0.01 0.9537037 0.7777778 1.0000000 Fold01
1.0 0.10 0.9537037 0.5555556 1.0000000 Fold01
1.0 0.00 0.9537037 0.7777778 1.0000000 Fold01
1.0 0.01 0.9629630 0.7777778 1.0000000 Fold01
0.0 0.10 0.7500000 0.7000000 0.7500000 Fold02
0.0 0.00 0.7333333 0.7000000 0.8333333 Fold02
0.0 0.01 0.7333333 0.7000000 0.8333333 Fold02
0.1 0.10 0.7416667 0.7000000 0.7500000 Fold02
0.1 0.00 0.7166667 0.6000000 0.7500000 Fold02
0.1 0.01 0.7083333 0.6000000 0.7500000 Fold02
0.3 0.10 0.7250000 0.7000000 0.7500000 Fold02
0.3 0.00 0.7166667 0.6000000 0.7500000 Fold02
0.3 0.01 0.7166667 0.6000000 0.7500000 Fold02
0.5 0.10 0.7166667 0.7000000 0.7500000 Fold02
0.5 0.00 0.7166667 0.6000000 0.7500000 Fold02
0.5 0.01 0.7166667 0.6000000 0.7500000 Fold02
0.9 0.10 0.6916667 0.7000000 0.7500000 Fold02
0.9 0.00 0.7166667 0.6000000 0.7500000 Fold02
0.9 0.01 0.7083333 0.7000000 0.7500000 Fold02
1.0 0.10 0.7083333 0.7000000 0.7500000 Fold02
1.0 0.00 0.7166667 0.6000000 0.7500000 Fold02
1.0 0.01 0.7083333 0.7000000 0.7500000 Fold02
0.0 0.10 1.0000000 1.0000000 1.0000000 Fold03
0.0 0.00 1.0000000 1.0000000 1.0000000 Fold03
0.0 0.01 1.0000000 1.0000000 1.0000000 Fold03
0.1 0.10 1.0000000 1.0000000 1.0000000 Fold03
0.1 0.00 1.0000000 1.0000000 1.0000000 Fold03
0.1 0.01 1.0000000 1.0000000 1.0000000 Fold03
0.3 0.10 1.0000000 1.0000000 1.0000000 Fold03
0.3 0.00 1.0000000 1.0000000 1.0000000 Fold03
0.3 0.01 1.0000000 1.0000000 1.0000000 Fold03
0.5 0.10 1.0000000 1.0000000 1.0000000 Fold03
0.5 0.00 1.0000000 1.0000000 1.0000000 Fold03
0.5 0.01 1.0000000 1.0000000 1.0000000 Fold03
0.9 0.10 1.0000000 1.0000000 1.0000000 Fold03
0.9 0.00 1.0000000 1.0000000 1.0000000 Fold03
0.9 0.01 1.0000000 1.0000000 1.0000000 Fold03
1.0 0.10 1.0000000 1.0000000 1.0000000 Fold03
1.0 0.00 1.0000000 1.0000000 1.0000000 Fold03
1.0 0.01 1.0000000 1.0000000 1.0000000 Fold03
0.0 0.10 0.8080808 0.3333333 0.8181818 Fold04
0.0 0.00 0.7878788 0.3333333 0.8181818 Fold04
0.0 0.01 0.7878788 0.3333333 0.8181818 Fold04
0.1 0.10 0.8080808 0.3333333 0.9090909 Fold04
0.1 0.00 0.6666667 0.3333333 0.8181818 Fold04
0.1 0.01 0.7272727 0.2222222 0.8181818 Fold04
0.3 0.10 0.7777778 0.3333333 0.9090909 Fold04
0.3 0.00 0.6666667 0.3333333 0.8181818 Fold04
0.3 0.01 0.7272727 0.2222222 0.8181818 Fold04
0.5 0.10 0.7373737 0.3333333 0.9090909 Fold04
0.5 0.00 0.6666667 0.3333333 0.8181818 Fold04
0.5 0.01 0.7272727 0.2222222 0.8181818 Fold04
0.9 0.10 0.7070707 0.3333333 0.7272727 Fold04
0.9 0.00 0.6666667 0.3333333 0.8181818 Fold04
0.9 0.01 0.7070707 0.2222222 0.8181818 Fold04
1.0 0.10 0.7070707 0.3333333 0.7272727 Fold04
1.0 0.00 0.6666667 0.3333333 0.8181818 Fold04
1.0 0.01 0.7070707 0.2222222 0.8181818 Fold04
0.0 0.10 0.8166667 0.7000000 0.8333333 Fold05
0.0 0.00 0.8250000 0.7000000 0.8333333 Fold05
0.0 0.01 0.8250000 0.7000000 0.8333333 Fold05
0.1 0.10 0.8166667 0.7000000 0.8333333 Fold05
0.1 0.00 0.8166667 0.6000000 0.8333333 Fold05
0.1 0.01 0.8166667 0.7000000 0.8333333 Fold05
0.3 0.10 0.7833333 0.7000000 0.7500000 Fold05
0.3 0.00 0.8083333 0.6000000 0.8333333 Fold05
0.3 0.01 0.8166667 0.7000000 0.8333333 Fold05
0.5 0.10 0.7666667 0.7000000 0.7500000 Fold05
0.5 0.00 0.8166667 0.6000000 0.8333333 Fold05
0.5 0.01 0.8166667 0.7000000 0.8333333 Fold05
0.9 0.10 0.7583333 0.7000000 0.7500000 Fold05
0.9 0.00 0.8166667 0.6000000 0.8333333 Fold05
0.9 0.01 0.8083333 0.7000000 0.8333333 Fold05
1.0 0.10 0.7416667 0.6000000 0.6666667 Fold05
1.0 0.00 0.8166667 0.6000000 0.8333333 Fold05
1.0 0.01 0.8083333 0.7000000 0.8333333 Fold05
0.0 0.10 0.7545455 0.6000000 0.8181818 Fold06
0.0 0.00 0.7363636 0.6000000 0.8181818 Fold06
0.0 0.01 0.7363636 0.6000000 0.8181818 Fold06
0.1 0.10 0.7636364 0.6000000 0.8181818 Fold06
0.1 0.00 0.7181818 0.6000000 0.7272727 Fold06
0.1 0.01 0.7181818 0.6000000 0.8181818 Fold06
0.3 0.10 0.7818182 0.5000000 0.8181818 Fold06
0.3 0.00 0.7181818 0.6000000 0.7272727 Fold06
0.3 0.01 0.7181818 0.6000000 0.8181818 Fold06
0.5 0.10 0.7909091 0.5000000 0.8181818 Fold06
0.5 0.00 0.7181818 0.6000000 0.7272727 Fold06
0.5 0.01 0.7272727 0.6000000 0.8181818 Fold06
0.9 0.10 0.7090909 0.4000000 0.7272727 Fold06
0.9 0.00 0.7181818 0.6000000 0.7272727 Fold06
0.9 0.01 0.7272727 0.6000000 0.8181818 Fold06
1.0 0.10 0.7090909 0.4000000 0.7272727 Fold06
1.0 0.00 0.7181818 0.6000000 0.7272727 Fold06
1.0 0.01 0.7272727 0.6000000 0.8181818 Fold06
0.0 0.10 0.9191919 0.8888889 0.8181818 Fold07
0.0 0.00 0.9393939 0.7777778 0.8181818 Fold07
0.0 0.01 0.9393939 0.7777778 0.8181818 Fold07
0.1 0.10 0.9090909 0.8888889 0.8181818 Fold07
0.1 0.00 0.9494949 0.8888889 1.0000000 Fold07
0.1 0.01 0.9595960 0.7777778 0.9090909 Fold07
0.3 0.10 0.8989899 0.7777778 0.8181818 Fold07
0.3 0.00 0.9494949 0.8888889 1.0000000 Fold07
0.3 0.01 0.9595960 0.7777778 0.9090909 Fold07
0.5 0.10 0.8787879 0.6666667 0.8181818 Fold07
0.5 0.00 0.9494949 0.8888889 1.0000000 Fold07
0.5 0.01 0.9595960 0.7777778 0.8181818 Fold07
0.9 0.10 0.7676768 0.6666667 0.7272727 Fold07
0.9 0.00 0.9494949 0.8888889 1.0000000 Fold07
0.9 0.01 0.9595960 0.7777778 0.8181818 Fold07
1.0 0.10 0.7474747 0.6666667 0.6363636 Fold07
1.0 0.00 0.9494949 0.8888889 1.0000000 Fold07
1.0 0.01 0.9595960 0.7777778 0.8181818 Fold07
0.0 0.10 0.9666667 0.8000000 0.8333333 Fold08
0.0 0.00 0.9583333 0.8000000 0.9166667 Fold08
0.0 0.01 0.9583333 0.8000000 0.9166667 Fold08
0.1 0.10 0.9666667 0.8000000 0.8333333 Fold08
0.1 0.00 0.9666667 0.8000000 0.9166667 Fold08
0.1 0.01 0.9583333 0.8000000 0.9166667 Fold08
0.3 0.10 0.9666667 0.8000000 0.9166667 Fold08
0.3 0.00 0.9666667 0.8000000 0.9166667 Fold08
0.3 0.01 0.9583333 0.8000000 0.9166667 Fold08
0.5 0.10 0.9666667 0.8000000 0.9166667 Fold08
0.5 0.00 0.9666667 0.8000000 0.9166667 Fold08
0.5 0.01 0.9583333 0.8000000 0.9166667 Fold08
0.9 0.10 0.8666667 0.8000000 0.9166667 Fold08
0.9 0.00 0.9666667 0.8000000 0.9166667 Fold08
0.9 0.01 0.9583333 0.8000000 0.9166667 Fold08
1.0 0.10 0.8583333 0.8000000 0.9166667 Fold08
1.0 0.00 0.9666667 0.8000000 0.9166667 Fold08
1.0 0.01 0.9583333 0.8000000 0.9166667 Fold08
0.0 0.10 0.9363636 0.9000000 0.9090909 Fold09
0.0 0.00 0.9454545 0.9000000 0.9090909 Fold09
0.0 0.01 0.9454545 0.9000000 0.9090909 Fold09
0.1 0.10 0.9363636 0.8000000 0.8181818 Fold09
0.1 0.00 0.9363636 0.9000000 0.9090909 Fold09
0.1 0.01 0.9454545 0.9000000 0.9090909 Fold09
0.3 0.10 0.9363636 0.8000000 0.8181818 Fold09
0.3 0.00 0.9363636 0.9000000 0.9090909 Fold09
0.3 0.01 0.9545455 0.9000000 0.9090909 Fold09
0.5 0.10 0.9272727 0.8000000 0.8181818 Fold09
0.5 0.00 0.9363636 0.9000000 0.9090909 Fold09
0.5 0.01 0.9545455 0.9000000 0.9090909 Fold09
0.9 0.10 0.9181818 0.8000000 0.8181818 Fold09
0.9 0.00 0.9363636 0.9000000 0.9090909 Fold09
0.9 0.01 0.9545455 0.9000000 0.9090909 Fold09
1.0 0.10 0.9181818 0.8000000 0.9090909 Fold09
1.0 0.00 0.9363636 0.9000000 0.9090909 Fold09
1.0 0.01 0.9545455 0.9000000 0.9090909 Fold09
0.0 0.10 0.9666667 1.0000000 0.8333333 Fold10
0.0 0.00 0.9500000 1.0000000 0.8333333 Fold10
0.0 0.01 0.9500000 1.0000000 0.8333333 Fold10
0.1 0.10 0.9666667 1.0000000 0.7500000 Fold10
0.1 0.00 0.9500000 1.0000000 0.8333333 Fold10
0.1 0.01 0.9500000 1.0000000 0.8333333 Fold10
0.3 0.10 0.9416667 1.0000000 0.5833333 Fold10
0.3 0.00 0.9500000 1.0000000 0.8333333 Fold10
0.3 0.01 0.9500000 1.0000000 0.8333333 Fold10
0.5 0.10 0.9333333 1.0000000 0.5833333 Fold10
0.5 0.00 0.9500000 1.0000000 0.8333333 Fold10
0.5 0.01 0.9500000 1.0000000 0.8333333 Fold10
0.9 0.10 0.9333333 1.0000000 0.5833333 Fold10
0.9 0.00 0.9500000 1.0000000 0.8333333 Fold10
0.9 0.01 0.9500000 1.0000000 0.7500000 Fold10
1.0 0.10 0.9333333 1.0000000 0.6666667 Fold10
1.0 0.00 0.9500000 1.0000000 0.8333333 Fold10
1.0 0.01 0.9500000 1.0000000 0.7500000 Fold10
performance$alpha <- as.factor(performance$alpha)
performance$lambda <- as.factor(performance$lambda)
ggplot(data = performance, aes(x = alpha, y = ROC,color=lambda)) +
    geom_boxplot() +
    geom_point(position=position_jitterdodge())+
    labs(x = "") +
    theme_bw() 

5. Test and compare models

We use the precrec library to compare the models.

5.1 train

#for models created with glmnet
pred_fit2 <- predict(
    object = fit_logistic_regression,
    newx = data.matrix(datos_train_prep[, -14])
)

pred_fit3 <- predict(
    object = fit_LogReg_cv_lasso,
    newx = data.matrix(datos_train_prep[, -14]),
    s = "lambda.min"
)

pred_fit4 <- predict(
    object = fit_LogReg_cv_ridge,
    newx = data.matrix(datos_train_prep[, -14]),
    s = "lambda.min"
)

pred_fit5 <- predict(
    object = fit_LogReg_cv_en,
    newx = data.matrix(datos_train_prep[, -14]),
    s = "lambda.min"
)
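Note that predict() on a glmnet object returns values on the link (log-odds) scale by default. Since the logistic function is monotone the ROC curves are unaffected, but probabilities can be requested explicitly with type = "response"; for example, for the lasso model:

```r
# Same predictions as pred_fit3 but on the probability scale; the
# ranking of the observations (and hence the ROC curve) is identical.
pred_fit3_prob <- predict(
    object = fit_LogReg_cv_lasso,
    newx = data.matrix(datos_train_prep[, -14]),
    s = "lambda.min",
    type = "response"
)
```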


#for models created with caret:
pred_caret <- predict(
    object = modelo_glm_caret,
    newdata =  datos_train_prep[-14],
    type = "prob"
)
pred_fit6 <- pred_caret$No  # in this dataset the class coded 0 ("No") means the patient does have the disease


pred_caret <- predict(
    object = modelo_glm_caret_grid,
    newdata =  datos_train_prep[-14],
    type = "prob"
)
pred_fit7 <- pred_caret$No 
labels <- (as.vector(datos_train_prep$target) == "No") + 0  # positive class: "No" (not having a heart attack)
mis_models <- mmdata(
    list(
        as.vector(pred_fit2),
        as.vector(pred_fit3),
        as.vector(pred_fit4),
        as.vector(pred_fit5),
        as.vector(pred_fit6),
        as.vector(pred_fit7)
    ),
    labels,
    modnames = c(
        "logistic_regression",
        "lasso",
        "ridge",
        "elastic-net",
        "caret",
        "caret_grid"
    )
)
auroc <- evalmod(mis_models)
autoplot(auroc)

5.2 test

Predictions:

#for models created with glmnet
pred_fit2 <- predict(
    object = fit_logistic_regression,
    newx = data.matrix(datos_test_prep[, -14])
)

pred_fit3 <- predict(
    object = fit_LogReg_cv_lasso,
    newx = data.matrix(datos_test_prep[, -14]),
    s = "lambda.min"
)

pred_fit4 <- predict(
    object = fit_LogReg_cv_ridge,
    newx = data.matrix(datos_test_prep[, -14]),
    s = "lambda.min"
)

pred_fit5 <- predict(
    object = fit_LogReg_cv_en,
    newx = data.matrix(datos_test_prep[, -14]),
    s = "lambda.min"
)


#for models created with caret:
pred_caret <- predict(
    object = modelo_glm_caret,
    newdata =  datos_test_prep[-14],
    type = "prob"
)
pred_fit6 <- pred_caret$No  # in this dataset the class coded 0 ("No") means the patient does have the disease


pred_caret <- predict(
    object = modelo_glm_caret_grid,
    newdata =  datos_test_prep[-14],
    type = "prob"
)
pred_fit7 <- pred_caret$No 

Final plots:

labels <- (as.vector(datos_test_prep$target) == "No") + 0  # positive class: "No" (not having a heart attack)
mis_models <- mmdata(
    list(
        as.vector(pred_fit2),
        as.vector(pred_fit3),
        as.vector(pred_fit4),
        as.vector(pred_fit5),
        as.vector(pred_fit6),
        as.vector(pred_fit7)
    ),
    labels,
    modnames = c(
        "logistic_regression",
        "lasso",
        "ridge",
        "elastic-net",
        "caret",
        "caret_grid"
    )
)
auroc <- evalmod(mis_models)
autoplot(auroc)
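If numeric values are preferred over the plot, precrec can also report the areas under the curves from the same evalmod object. A sketch, assuming auc() returns one row per model and curve type:

```r
# Numeric AUCs for each model on the test set; keep only the ROC rows.
aucs <- auc(auroc)
subset(aucs, curvetypes == "ROC")
```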